package com.yzh.dbrouter;

import com.yzh.dbrouter.annotation.DBRouter;
import com.yzh.dbrouter.strategy.IDBRouterStrategy;
import org.apache.commons.beanutils.BeanUtils;
import org.apache.commons.lang.StringUtils;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Pointcut;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@Aspect
public class DBRouterJoinPoint {
    private Logger logger = LoggerFactory.getLogger(DBRouterJoinPoint.class);
    private DBRouterConfig dbRouterConfig;
    private IDBRouterStrategy dbRouterStrategy;

    public DBRouterJoinPoint(DBRouterConfig dbRouterConfig, IDBRouterStrategy dbRouterStrategy) {
        this.dbRouterStrategy = dbRouterStrategy;
        this.dbRouterConfig = dbRouterConfig;
    }

    @Pointcut("@annotation(com.yzh.dbrouter.annotation.DBRouter)")
    public void aopPoint() {
    }

    /**
     * 所有需要分库分表的操作，都需要使用自定义注解进行拦截，拦截后读取方法中的入参字段，根据字段进行路由操作
     * dbRouter.key() -> 确定根据哪个字段进行路由
     * getAttrValue(String dbKey, Object[] args) -> 根据原方法入参中找到匹配dbKey的参数，比如dbKey是uId，那么就从入参中找到uId的值
     * dbRouterStrategy.doRouter(dbKeyAttr) -> 路由策略，根据路由值进行处理
     * jp.proceed() -> 执行原方法
     * dbRouterStrategy.clear() -> 这里手动清空ThreadLocal的原因是防止内存泄露
     */
    @Around("aopPoint() && @annotation(dbRouter)")
    public Object doRouter(ProceedingJoinPoint jp, DBRouter dbRouter) throws Throwable {
        String dbKey = dbRouter.key();
        if (StringUtils.isBlank(dbKey) && StringUtils.isBlank(dbRouterConfig.getRouterKey())) {
            throw new RuntimeException("annotation DBRouter key is null!");
        }
        //若@DBRouter没有设置key，那么就走yml配置中的默认配置routerKey: uId
        dbKey = StringUtils.isNotBlank(dbKey) ? dbKey : dbRouterConfig.getRouterKey();
        //路由属性，获取方法入参中的dbKey的值
        String dbKeyAttr = getAttrValue(dbKey, jp.getArgs());
        //路由策略
        dbRouterStrategy.dbRouter(dbKeyAttr);
        //返回结果
        try {
            return jp.proceed();
        } finally {
            dbRouterStrategy.clear();
        }
    }

    public String getAttrValue(String attr, Object[] args) {
        //如果入参只有一个，那么就是路由值
        if (1 == args.length) {
            Object arg = args[0];
            if (arg instanceof String) {
                return (String) arg;
            }
        }
        String filedValue = null;
        //否则遍历寻找入参
        for (Object arg : args) {
            try {
                if (StringUtils.isNotBlank(filedValue)) {
                    break;
                }
                filedValue = BeanUtils.getProperty(arg, attr);
            } catch (Exception e) {
                logger.error("获取路由属性值失败 attr：{}", attr, e);
            }
        }
        return filedValue;
    }
    
}
